# -*- coding: utf-8 -*-
"""
Created on Mon Feb  4 07:36:00 2019

@author: Salem
"""

import numpy as np
import numpy.linalg as la


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


#-------------------------------------
from pathlib import WindowsPath
from pathlib import Path
import os

# All result paths are resolved relative to the directory the script is run from.
cwd = Path.cwd()
result_dir = cwd / "Run Results"
#-------------------------------------

import Aligning_Meshes as AM



# --- MHD results --------------------------------------------------------------
# For each mesh set (CH / Top / WTRC) and variant ("Affine" / "Original") the
# raw pairwise matrix is loaded from "distances.npy", and a symmetrised copy
# 0.5 * (D + D^T) is kept as D_<set>_MHD<A|O>.  (The raw matrices are
# presumably not exactly symmetric; asymetric_part() measures by how much.)

CH_MHD_Affine_dir = result_dir / "MHD Results" / "CH_Affine"
CH_MHD_Affine_distances = np.load(CH_MHD_Affine_dir / "distances.npy")
D_CH_MHDA = 0.5 * (CH_MHD_Affine_distances + CH_MHD_Affine_distances.T)

CH_MHD_Original_dir = result_dir / "MHD Results" / "CH_Original"
CH_MHD_Original_distances = np.load(CH_MHD_Original_dir / "distances.npy")
D_CH_MHDO = 0.5 * (CH_MHD_Original_distances + CH_MHD_Original_distances.T)

Top_MHD_Affine_dir = result_dir / "MHD Results" / "Top_Affine"
Top_MHD_Affine_distances = np.load(Top_MHD_Affine_dir / "distances.npy")
D_Top_MHDA = 0.5 * (Top_MHD_Affine_distances + Top_MHD_Affine_distances.T)

Top_MHD_Original_dir = result_dir / "MHD Results" / "Top_Original"
Top_MHD_Original_distances = np.load(Top_MHD_Original_dir / "distances.npy")
D_Top_MHDO = 0.5 * (Top_MHD_Original_distances + Top_MHD_Original_distances.T)

WTRC_MHD_Affine_dir = result_dir / "MHD Results" / "WTRC_Affine"
WTRC_MHD_Affine_distances = np.load(WTRC_MHD_Affine_dir / "distances.npy")
D_WTRC_MHDA = 0.5 * (WTRC_MHD_Affine_distances + WTRC_MHD_Affine_distances.T)

WTRC_MHD_Original_dir = result_dir / "MHD Results" / "WTRC_Original"
WTRC_MHD_Original_distances = np.load(WTRC_MHD_Original_dir / "distances.npy")
D_WTRC_MHDO = 0.5 * (WTRC_MHD_Original_distances + WTRC_MHD_Original_distances.T)

#---------------------------------------------------------------------------------------------------
# Permutation of the 59 mesh indices, used by the (now disabled) reordering
# step described below.
ICP_order = [29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 47, 48, 44, 45, 46,
             16, 17, 18, 19, 20, 21, 22, 27, 28, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5,
             11, 12, 13, 14, 15, 23, 24, 25, 26, 49, 50, 51, 52, 53, 42, 43, 54, 55, 56, 57, 58]

HD_dir = result_dir / "HD Results"
HD_name_list = np.genfromtxt(result_dir / "gary_name_list.csv", dtype='str')

# The meshes have already been reordered, so each comma-separated matrix is used
# as loaded.  The one-off step that reordered and re-saved a matrix D was:
#     D[:, :] = D[ICP_order, :];  D[:, :] = D[:, ICP_order]
#     np.savetxt(<dir> / "distances.csv", <symmetrised D>, delimiter=',')
# Each variant keeps its raw matrix and the symmetrised 0.5 * (D + D^T).

Top_HD_Affine_dir = HD_dir / "Top_Affine"
Top_HD_Affine_distances = np.loadtxt(Top_HD_Affine_dir / "distances.csv", delimiter=',')
D_Top_HDA = 0.5 * (Top_HD_Affine_distances + Top_HD_Affine_distances.T)

Top_HD_Rigid_dir = HD_dir / "Top_Rigid"
Top_HD_Rigid_distances = np.loadtxt(Top_HD_Rigid_dir / "distances.csv", delimiter=',')
D_Top_HDR = 0.5 * (Top_HD_Rigid_distances + Top_HD_Rigid_distances.T)

Top_HD_Uniform_dir = HD_dir / "Top_Uniform"
Top_HD_Uniform_distances = np.loadtxt(Top_HD_Uniform_dir / "distances.csv", delimiter=',')
D_Top_HDU = 0.5 * (Top_HD_Uniform_distances + Top_HD_Uniform_distances.T)

Top_HD_Rescale_dir = HD_dir / "Top_Rescale"
Top_HD_Rescale_distances = np.loadtxt(Top_HD_Rescale_dir / "distances.csv", delimiter=',')
D_Top_HDS = 0.5 * (Top_HD_Rescale_distances + Top_HD_Rescale_distances.T)




'''
Top_HD_shears =  np.loadtxt(HD_dir / "shears.csv")
Top_HD_shears[:,:] = Top_HD_shears[ICP_order,:]
Top_HD_shears[:,:] = Top_HD_shears[:,ICP_order]
'''
#------------------------------------

'''
WTRC_HD_Affine_dir = result_dir / "HD Results" / "WTRC_Affine" 
WTRC_HD_Affine_distances =  np.load(WTRC_HD_Affine_dir / "distances.npy")
WTRC_HD_Affine_distances[:,:] = WTRC_HD_Affine_distances[ICP_order,:]
WTRC_HD_Affine_distances[:,:] = WTRC_HD_Affine_distances[:,ICP_order]
D_WTRC_HDA = 0.5 * (WTRC_HD_Affine_distances + np.transpose(WTRC_HD_Affine_distances))

WTRC_HD_Rigid_dir = result_dir / "HD Results" / "WTRC_Rigid" 
WTRC_HD_Rigid_distances =  np.load(WTRC_HD_Rigid_dir / "distances.npy")
WTRC_HD_Rigid_distances[:,:] = WTRC_HD_Rigid_distances[ICP_order,:]
WTRC_HD_Rigid_distances[:,:] = WTRC_HD_Rigid_distances[:,ICP_order]
D_WTRC_HDR = 0.5 * (WTRC_HD_Rigid_distances + np.transpose(WTRC_HD_Rigid_distances))

WTRC_HD_Uniform_dir = result_dir / "HD Results" / "WTRC_Uniform" 
WTRC_HD_Uniform_distances =  np.load(WTRC_HD_Uniform_dir / "distances.npy")
WTRC_HD_Uniform_distances[:,:] = WTRC_HD_Uniform_distances[ICP_order,:]
WTRC_HD_Uniform_distances[:,:] = WTRC_HD_Uniform_distances[:,ICP_order]
D_WTRC_HDU = 0.5 * (WTRC_HD_Uniform_distances + np.transpose(WTRC_HD_Uniform_distances))

WTRC_HD_Rescale_dir = result_dir / "HD Results" / "WTRC_Rescale" 
WTRC_HD_Rescale_distances =  np.load(WTRC_HD_Rescale_dir / "distances.npy")
WTRC_HD_Rescale_distances[:,:] = WTRC_HD_Rescale_distances[ICP_order,:]
WTRC_HD_Rescale_distances[:,:] = WTRC_HD_Rescale_distances[:,ICP_order]
D_WTRC_HDS = 0.5 * (WTRC_HD_Rescale_distances + np.transpose(WTRC_HD_Rescale_distances))

'''


#-----------------------------------------------------------------------

# --- PH results ---------------------------------------------------------------
# Each raw pairwise matrix is loaded and a symmetrised copy 0.5 * (D + D^T) is
# kept as D_<set>_PH<A|O>.

CH_PH_Affine_dir = result_dir / "PH Results" / "CH_Affine"
CH_PH_Affine_distances = np.load(CH_PH_Affine_dir / "distances.npy")
D_CH_PHA = 0.5 * (CH_PH_Affine_distances + CH_PH_Affine_distances.T)

CH_PH_Original_dir = result_dir / "PH Results" / "CH_Original"
CH_PH_Original_distances = np.load(CH_PH_Original_dir / "distances.npy")
D_CH_PHO = 0.5 * (CH_PH_Original_distances + CH_PH_Original_distances.T)

WTRC_PH_Affine_dir = result_dir / "PH Results" / "WTRC_Affine"
WTRC_PH_Affine_distances = np.load(WTRC_PH_Affine_dir / "distances.npy")
D_WTRC_PHA = 0.5 * (WTRC_PH_Affine_distances + WTRC_PH_Affine_distances.T)

WTRC_PH_Original_dir = result_dir / "PH Results" / "WTRC_Original"
WTRC_PH_Original_distances = np.load(WTRC_PH_Original_dir / "distances.npy")
D_WTRC_PHO = 0.5 * (WTRC_PH_Original_distances + WTRC_PH_Original_distances.T)

# NOTE(review): the two Top matrices are parsed as comma-separated *text* even
# though the file is named "distances.npy" (the other PH matrices use np.load).
# Presumably they were written with np.savetxt; confirm the on-disk format.
Top_PH_Original_dir = result_dir / "PH Results" / "Top_Original"
Top_PH_Original_distances = np.loadtxt(Top_PH_Original_dir / "distances.npy", delimiter=',')
D_Top_PHO = 0.5 * (Top_PH_Original_distances + Top_PH_Original_distances.T)

Top_PH_Affine_dir = result_dir / "PH Results" / "Top_Affine"
Top_PH_Affine_distances = np.loadtxt(Top_PH_Affine_dir / "distances.npy", delimiter=',')
D_Top_PHA = 0.5 * (Top_PH_Affine_distances + Top_PH_Affine_distances.T)

# "gain" = largest symmetrised distance in the Original matrix divided by the
# largest one in the Affine matrix of the same result set.
print("MHD_CH gain:", D_CH_MHDO.max() / D_CH_MHDA.max())
print("MHD_WTR gain:", D_WTRC_MHDO.max() / D_WTRC_MHDA.max())
print("MHD_Top gain:", D_Top_MHDO.max() / D_Top_MHDA.max())

print("PH_CH gain:", D_CH_PHO.max() / D_CH_PHA.max())
print("PH_WTR gain:", D_WTRC_PHO.max() / D_WTRC_PHA.max())

# Disabled.  D_WTRC_HD* only exists if the WTRC HD section above is re-enabled.
# print("HD_Top gain:", D_Top_HDR.max() / D_Top_HDA.max())
# print("HD_WTR gain:", D_WTRC_HDR.max() / D_WTRC_HDA.max())
# print("PH_Top gain:", D_Top_PHO.max() / D_Top_PHA.max())

#=========================================================================================================================
    # plot the vertices as a scatter plot
#========================================================================================================================= 
def plot(vertices):
    """Draw ``vertices`` as a 3-D scatter plot of red circular markers.

    Parameters
    ----------
    vertices : array-like
        Indexed as ``vertices[:, k]`` for k = 0, 1, 2: one point per row, with
        the first three columns used as the x, y and z coordinates.

    The figure is created but not shown or saved; returns None.
    """
    ax = plt.figure().add_subplot(111, projection='3d')

    xs, ys, zs = (vertices[:, k] for k in range(3))
    ax.scatter(xs, ys, zs, c='r', marker='o')

    for set_label, text in ((ax.set_xlabel, 'X Label'),
                            (ax.set_ylabel, 'Y Label'),
                            (ax.set_zlabel, 'Z Label')):
        set_label(text)
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================   
    
#=========================================================================================================================
#displays a color plot of the given matrix
#=========================================================================================================================  
def plot_matrix(mat, save_name='Heat-Map.png', save_dir=result_dir, labels=None):
    """Draw ``mat`` as a seaborn heat map, save it, then show it.

    Parameters
    ----------
    mat : 2-D array-like
        Matrix to display.
    save_name : str
        File name of the saved image (written with dpi=300).
    save_dir : pathlib.Path
        Directory the image is written to.
    labels : sequence of str or "auto", optional
        Tick labels used for both axes. When omitted, the module-level
        ``names`` list is used for an axis only if its length matches that
        axis. Otherwise seaborn's automatic index labels are used, so a matrix
        of a different size (e.g. 33x33) is not labelled with the wrong names.
    """
    import seaborn as sns; sns.set()

    if labels is None:
        n_rows, n_cols = np.shape(mat)
        x_labels = names if len(names) == n_cols else "auto"
        y_labels = names if len(names) == n_rows else "auto"
    else:
        x_labels = y_labels = labels

    fig, ax = plt.subplots(figsize=(11.6, 10))  # figure size in inches
    sns.heatmap(mat, xticklabels=x_labels, yticklabels=y_labels, cmap="winter_r", ax=ax)

    plt.savefig(save_dir / save_name, dpi=300)

    plt.show()
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  

#=========================================================================================================================
# Applies multidimensional scaling to visualize the distance matrix
#=========================================================================================================================  
def make_MDS(distance_matrix, save_name='MDSplot.png', save_dir=result_dir, coloring=None):
    """Embed a precomputed distance matrix in 2-D with MDS, then plot and save it.

    Parameters
    ----------
    distance_matrix : 2-D array
        Pairwise dissimilarities, passed to sklearn's MDS with
        ``dissimilarity='precomputed'`` (random_state=1 for reproducibility).
    save_name : str
        File name of the saved scatter plot (dpi=300).
    save_dir : pathlib.Path
        Directory the plot is written to.
    coloring : array-like, optional
        One value per row of ``distance_matrix``, mapped through a 12-level
        'rainbow' colormap. Defaults to the hard-coded genus codes below when
        their length matches the matrix; otherwise all points get the default
        colour instead of raising a size mismatch in ``scatter``.
    """
    from sklearn.manifold import MDS

    model = MDS(n_components=2, dissimilarity='precomputed', random_state=1)
    out = model.fit_transform(distance_matrix)

    if coloring is None:
        # Genus code (1..6) per row, 59 entries.
        # NOTE(review): these codes follow numeric beak order (beak1..beak59, with
        # G. difficilis = beak28/29 inside the leading 1s), whereas the module-level
        # `names` list places beak28/beak29 after beak49. Confirm which order the
        # rows of `distance_matrix` use.
        genera = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6])
        coloring = genera if len(genera) == len(out) else None

    fig, ax = plt.subplots(figsize=(11.6, 11.6))

    # plt.get_cmap replaces matplotlib.cm.get_cmap (plt.cm.get_cmap), which was
    # removed in Matplotlib 3.9.
    colorize = {} if coloring is None else dict(c=coloring, cmap=plt.get_cmap('rainbow', 12))

    plt.scatter(out[:, 0], out[:, 1], **colorize, s=75)

    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.tick_params(axis='both', which='minor', labelsize=14)
    plt.axis('equal')

    plt.savefig(save_dir / save_name, dpi=300)
    plt.show()
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#========================================================================================================================= 
 
    
#=========================================================================================================================
# Applies multidimensional scaling to the distance matrix and saves the embedding coordinates to a file
#=========================================================================================================================  
def get_MDS_coords(distance_matrix, save_dir, save_name='MDScoords.npy', n_components=2):
    """Embed a precomputed distance matrix with MDS and write the coordinates.

    Parameters
    ----------
    distance_matrix : 2-D array
        Pairwise dissimilarities (``dissimilarity='precomputed'``, random_state=1).
    save_dir : pathlib.Path
        Directory the coordinates are written to.
    save_name : str
        Output file name. NOTE: despite the default ``.npy`` suffix, the file is
        written with ``np.savetxt`` as comma-separated text (one row per item,
        ``n_components`` columns). Read it back with
        ``np.loadtxt(path, delimiter=',')``, not ``np.load``.
    n_components : int
        Dimension of the embedding.

    Returns None.
    """
    from sklearn.manifold import MDS

    embedding = MDS(n_components=n_components, dissimilarity='precomputed',
                    random_state=1).fit_transform(distance_matrix)
    np.savetxt(save_dir / save_name, embedding, delimiter=',')
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#========================================================================================================================= 

#=========================================================================================================================
# Draws and saves a stand-alone colour legend for the species
#=========================================================================================================================  
def make_species_legend():
    """Draw a stand-alone legend: one coloured dot plus an italic name per species.

    Dot colours are the species index 0..14 on a 16-level 'rainbow' colormap.
    The legend is saved to ``result_dir / 'species_legend.png'`` (dpi=300) and
    then shown.
    """
    species = np.arange(15)  # 0-based species index, one per label below

    fig, ax = plt.subplots(figsize=(5, 10))

    # plt.get_cmap replaces matplotlib.cm.get_cmap (plt.cm.get_cmap), which was
    # removed in Matplotlib 3.9.
    colorize = dict(c=species, cmap=plt.get_cmap('rainbow', 16))

    n_species = species.shape[0]
    x = 2 * np.ones(n_species)
    y = (2.5 * np.arange(n_species) / n_species) - 1.5  # evenly spaced column of dots
    plt.axis('off')
    plt.scatter(x, y, s=100, **colorize)

    labels = ['  G. fuliginosa', '  G. fortis', '  G. magnirostris', '  G. conirostris',
              '  G. scandens', '  G. septentrionalis', '  G. difficilis', '  C. pallidus',
              '  C. parvulus', '  C. psittacula', '  P. inornata', '  C. olivacea',
              '  C. fusca', '  P. crassirostris', '  T. bicolor']

    for xi, yi, txt in zip(x, y, labels):
        plt.annotate(txt, (xi, yi), fontsize=16, fontstyle='italic')

    plt.savefig(result_dir / 'species_legend.png', dpi=300)
    plt.show()
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#========================================================================================================================= 

#=========================================================================================================================
# takes the transformations from a directory and visualizes the mean and std transformation
#=========================================================================================================================  
def visualize_transformations(save_dir):
    """Print, plot and save two 3x3 summaries of the stored transformations.

    Loads ``save_dir / "transformations.npy"`` (as written by
    get_transformations, presumably shape (n, n, 4, 4); confirm). For each
    summary, a statistic is taken over axis 0, averaged over the next axis,
    and cropped to the upper-left 3x3 block:

    * ``np.mean`` -> saved as ``save_dir / "mean-mean-transformation"``
    * ``np.std``  -> saved as ``save_dir / "mean-std-transformation"``

    Each summary is printed, drawn as a seaborn heat map with X/Y/Z tick
    labels, saved (dpi=300; no extension is given, so Matplotlib's default
    savefig format applies) and shown.
    """
    import seaborn as sns; sns.set()

    transformations = np.load(save_dir / "transformations.npy")
    axis_labels = ['X', 'Y', 'Z']

    summaries = (
        (np.mean, "mean-mean-transformation"),
        (np.std, "mean-std-transformation"),
    )
    for reduce_first, file_name in summaries:
        summary = np.mean(reduce_first(transformations, axis=0), axis=0)[:3, :3]

        print(summary)

        fig, ax = plt.subplots(figsize=(3.3, 3))  # figure size in inches
        sns.heatmap(summary, xticklabels=axis_labels, yticklabels=axis_labels,
                    cmap="winter_r", ax=ax)

        plt.savefig(save_dir / file_name, dpi=300)
        plt.show()
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  


#=========================================================================================================================
# compare the intraspecies distance to the interspecies distance
#=========================================================================================================================  
def inter_vs_intra(distances, species_lists=None, num_species=33):
    """Return mean between-species distance divided by mean within-species distance.

    Parameters
    ----------
    distances : 2-D array
        Pairwise distance matrix indexed as ``distances[i, j]``.
    species_lists : iterable of iterables of int, optional
        Groups of row indices that belong to the same species. Defaults to the
        12 groups of the 33-specimen layout hard-coded below. Groups may differ
        in size; they are kept in a plain list because building a ragged
        ``np.array`` raises ValueError on NumPy >= 1.24.
    num_species : int
        Number of columns each grouped row is compared against (despite the
        name, this counts specimens, not species).

    Returns
    -------
    float
        mean(d[i, j] for i, j in different species) divided by
        mean(d[i, j] for i != j in the same species). Values > 1 mean that
        specimens are closer to their own species than to other species.
        Returns nan (with a RuntimeWarning) if either set of pairs is empty.
    """
    if species_lists is None:
        species_lists = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13], [14, 15],
                         [16, 17, 18], [19, 20, 21], [22, 23, 24], [25, 26, 27], [28, 29], [30, 31, 32]]

    within_species = []   # same species, different specimens
    between_species = []  # different species

    for index_list in species_lists:
        members = set(index_list)
        for index1 in index_list:
            for index2 in range(num_species):
                if index2 in members:
                    if index1 != index2:
                        within_species.append(distances[index1, index2])
                else:
                    between_species.append(distances[index1, index2])

    return np.mean(between_species) / np.mean(within_species)
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  

#=========================================================================================================================
# Finds the asymmetric part of the matrix divided by the symmetric part (zero where the symmetric element is zero), plots it and returns its maximum
#=========================================================================================================================  
def asymetric_part(mat):
    """Measure how far ``mat`` is from being symmetric.

    Computes, element-wise, ``|M - M^T| / (0.5 * (M + M^T))``, using 0 wherever
    the symmetric part is 0. The ratio matrix is drawn with plot_matrix (which
    by default saves 'Heat-Map.png' into result_dir), and its largest entry is
    returned.

    Assumes ``mat`` is a square NumPy array (it is compared with its transpose
    element-wise).
    """
    mat_t = mat.transpose()
    symmetric = 0.5 * (mat + mat_t)
    asym_abs = np.abs(mat - mat_t)

    ratio = np.zeros_like(asym_abs)
    np.divide(asym_abs, symmetric, out=ratio, where=symmetric != 0)

    plot_matrix(ratio)

    return ratio.max()
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  
    
NUM_BEAKS = 33  # default number of specimens handled by principal_trans
#=========================================================================================================================
# finds the principal transformation using principal component analysis
#========================================================================================================================= 
def _rank1_pca_reconstruction(T):
    """Project the rows of ``T`` onto their first principal component.

    Returns ``(T_new, meanT)``. ``meanT`` is the column mean of ``T``, and
    ``T_new`` is the one-component PCA reconstruction of the *centred* rows
    ``T - meanT`` (``meanT`` is not added back). Prints the explained variance.
    """
    from sklearn.decomposition import PCA

    meanT = np.mean(T, axis=0)
    centred = T - meanT

    pca = PCA(n_components=1)
    pca.fit(centred)

    print("explained variance: ", pca.explained_variance_)

    T_new = pca.inverse_transform(pca.transform(centred))
    return T_new, meanT


def principal_trans(load_dir=Top_MHD_Affine_dir, isDirectional=True, to_index=0, num_beaks=NUM_BEAKS):
    '''
    Find the principal transformation using principal component analysis.

    The stored transformations ``t = np.load(load_dir / "transformations.npy")``
    are indexed as ``t[i, j]``, each entry a 4x4 matrix that is flattened to 16
    values before a one-component PCA is fitted to them.

    Parameters
    ----------
    load_dir : pathlib.Path
        Directory containing ``transformations.npy``.
    isDirectional : bool
        If True, use only the transformations ``t[i, to_index]`` for
        i < num_beaks. Otherwise use every pair ``t[i, j]`` with i, j < num_beaks.
    to_index : int
        Index of the fixed species that the transformations work on.
        Transformations act to transform one mesh to the other; here we fix
        one of those.
    num_beaks : int
        Number of specimens to use.

    Returns
    -------
    (T_new, meanT)
        ``T_new`` has shape (num_beaks, 4, 4) if directional, else
        (num_beaks * num_beaks, 4, 4). It is the rank-1 PCA reconstruction of
        the mean-centred flattened transformations. ``meanT`` (16 values) is
        the mean flattened transformation that was subtracted.
    '''
    t = np.load(load_dir / "transformations.npy")

    if isDirectional:
        # Column `to_index` holds the transformations involving the fixed species.
        fixed = t[:, to_index]
        T = np.asarray(fixed[:num_beaks], dtype=float).reshape(num_beaks, 16)
        T_new, meanT = _rank1_pca_reconstruction(T)
        return T_new.reshape((num_beaks, 4, 4)), meanT

    # Row i * num_beaks + j holds the flattened t[i, j] (row-major order).
    T = np.asarray(t[:num_beaks, :num_beaks], dtype=float).reshape(num_beaks * num_beaks, 16)
    T_new, meanT = _rank1_pca_reconstruction(T)
    return T_new.reshape((num_beaks * num_beaks, 4, 4)), meanT
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================      

#c = np.divide(a, b, out=np.zeros_like(a), where=b!=0)

# Specimen names, "<genus>.<species><specimen letter>". beak1..beak59 are
# numbered species by species. `names` gives the order used for the heat-map
# labels (plot_matrix) and for file matching (get_transformations,
# collapse_meshes). Note that G. difficilis (beak28, beak29) is placed after
# the C2 specimens there.
# NOTE(review): the genus codes in make_MDS follow the numeric beak1..beak59
# order instead. Confirm which order the distance matrices actually use.
beak1, beak2, beak3, beak4, beak5, beak6 = ("G.Fuliginosa" + s for s in "ABCDEF")
beak7, beak8, beak9, beak10, beak11 = ("G.Fortis" + s for s in "ABCDE")
beak12, beak13, beak14, beak15, beak16 = ("G.Magnirostris" + s for s in "ABCDE")
beak17, beak18, beak19, beak20, beak21, beak22, beak23 = ("G.Conirostris" + s for s in "ABCDEFG")
beak24, beak25 = ("G.Scandens" + s for s in "AB")
beak26, beak27 = ("G.Septentrionalist" + s for s in "AB")
beak28, beak29 = ("G.Difficilis" + s for s in "AB")
beak30, beak31, beak32, beak33 = ("C.Pallidus" + s for s in "ABCD")
beak34, beak35, beak36, beak37, beak38 = ("C.Parvulus" + s for s in "ABCDE")
beak39, beak40, beak41, beak42 = ("C.Psittacula" + s for s in "ABCD")
beak43, beak44 = ("P2.Inornata" + s for s in "AB")
beak45, beak46, beak47 = ("C2.Olivacea" + s for s in "ABC")
beak48, beak49 = ("C2.Fusca" + s for s in "AB")
beak50, beak51, beak52, beak53, beak54 = ("P.Crassirostris" + s for s in "ABCDE")
beak55, beak56, beak57, beak58, beak59 = ("T.Bicolor" + s for s in "ABCDE")

names = np.array([
    # beak1 .. beak27
    beak1, beak2, beak3, beak4, beak5, beak6, beak7, beak8, beak9,
    beak10, beak11, beak12, beak13, beak14, beak15, beak16, beak17, beak18,
    beak19, beak20, beak21, beak22, beak23, beak24, beak25, beak26, beak27,
    # beak30 .. beak49
    beak30, beak31, beak32, beak33, beak34, beak35, beak36, beak37, beak38, beak39,
    beak40, beak41, beak42, beak43, beak44, beak45, beak46, beak47, beak48, beak49,
    # beak28, beak29 (G. difficilis)
    beak28, beak29,
    # beak50 .. beak59
    beak50, beak51, beak52, beak53, beak54, beak55, beak56, beak57, beak58, beak59,
])




def path_leaf(path):
    """Return the last component of ``path``.

    Both '/' and '\\' are treated as separators (via ``ntpath``), so this works
    for Windows-style paths on any OS. A trailing separator is ignored, e.g.
    ``'a/b/'`` gives ``'b'``.
    """
    import ntpath

    directory, leaf = ntpath.split(path)
    if leaf:
        return leaf
    return ntpath.basename(directory)


def get_transformations(root_dir, save_dir, name='transformations', beak_names=None):
    """Collect pairwise 4x4 transformation matrices from text files into one array.

    Every regular file under ``root_dir`` (searched recursively) is examined.
    If the file name starts with ``beak_names[i]`` and also contains
    ``beak_names[j]`` (j != i), it is read with ``np.loadtxt`` into
    ``transformations[i, j]``. Diagonal entries are the 4x4 identity, and pairs
    with no matching file stay all-zero. Matching is by substring, so a name
    that is a substring of another name could match the wrong file.

    Parameters
    ----------
    root_dir : pathlib.Path
        Directory tree holding the matrix text files.
    save_dir : pathlib.Path
        The result is also saved to ``save_dir / (name + '.npy')``.
    name : str
        Base name of the saved ``.npy`` file.
    beak_names : sequence of str, optional
        Specimen names that define the row/column order. Defaults to the
        module-level ``names``.

    Returns
    -------
    numpy.ndarray of shape (len(beak_names), len(beak_names), 4, 4)
    """
    if beak_names is None:
        beak_names = names

    # Size the array from the name list; a hard-coded size breaks as soon as
    # the list has more entries than that size.
    num = len(beak_names)
    transformations = np.zeros((num, num, 4, 4))
    for i in range(num):
        transformations[i, i] = np.eye(4)

    # Return a list of regular files only, not directories
    file_list = [f for f in root_dir.glob('**/*') if f.is_file()]

    for f in file_list:
        leaf = f.name
        for i, from_name in enumerate(beak_names):
            if not leaf.startswith(from_name):
                continue
            for j, to_name in enumerate(beak_names):
                if i != j and to_name in leaf:
                    transformations[i, j] = np.loadtxt(f)

    np.save(save_dir / (name + '.npy'), transformations)
    return transformations


def collapse_meshes(root_dir, transformations):
    '''
    Apply the matching transformation to every mesh file found under root_dir.

    root_dir contains the meshes to be transformed (searched recursively).
    transformations is a sequence of 4x4 matrices indexed like the
    module-level ``names`` list: ``transformations[j]`` is applied to every
    file whose name contains ``names[j]``. Its upper-left 3x3 block, transposed,
    is passed to ``AM.apply_transformation`` as the matrix, and the first three
    entries of its last column are passed as the translation.
    '''
    file_list = [f for f in root_dir.glob('**/*') if f.is_file()]

    # Visit every file and pick transformations by name. Indexing file_list by
    # the position of a name would fail with fewer files than names and would
    # skip any files beyond that count.
    for mesh_path in file_list:
        leaf = path_leaf(mesh_path)
        for j, beak_name in enumerate(names):
            if beak_name in leaf:
                transformation = transformations[j]
                matrix = np.transpose(transformation[:3, :3])
                translation = transformation[:3, 3]

                AM.apply_transformation(mesh_path, matrix, translation)
    
    
'''   
mesh_dir = result_dir  / "MHD Results" / "CH_Affine" / "Meshes"   
    
beak1 =  mesh_dir / "G.FortisA_WTR.stl"
beak2 =  mesh_dir / "G.FortisB_WTR.stl"
beak3 =  mesh_dir / "G.FortisC_WTR.stl"
beak4 = mesh_dir / "G.MagnirostrisA_WTR.stl"
beak5 = mesh_dir / "G.MagnirostrisB_WTR.stl"
beak6 = mesh_dir / "G.MagnirostrisC_WTR.stl"
beak7 = mesh_dir /"G.FuliginosaA_WTR.stl"
beak8 = mesh_dir /"G.FuliginosaB_WTR.stl"
beak9 = mesh_dir /"G.FuliginosaC_WTR.stl"
beak10 =  mesh_dir / "G.ScandensA_WTR.stl"
beak11 =  mesh_dir / "G.ScandensB_WTR.stl"
beak12 =  mesh_dir / "G.SeptentrionalistA_WTR.stl"
beak13 =  mesh_dir / "G.SeptentrionalistB_WTR.stl"
beak14 = mesh_dir / "C.PallidusA_WTR.stl"
beak15 = mesh_dir / "C.PallidusB_WTR.stl"
beak16 = mesh_dir / "C.PallidusC_WTR.stl"
beak17 = mesh_dir / "C.PsittaculaA_WTR.stl"
beak18 = mesh_dir / "C.PsittaculaB_WTR.stl"
beak19 = mesh_dir / "C.PsittaculaC_WTR.stl"
beak20 = mesh_dir / "C.ParvulusA_WTR.stl"
beak21 = mesh_dir / "C.ParvulusB_WTR.stl"
beak22 = mesh_dir / "C.ParvulusC_WTR.stl"
beak23 = mesh_dir / "C2.FuscaA_WTR.stl"
beak24 = mesh_dir / "C2.FuscaB_WTR.stl"
beak25 = mesh_dir / "C2.OlivaceaA_WTR.stl"
beak26 = mesh_dir / "C2.OlivaceaB_WTR.stl"
beak27 = mesh_dir / "C2.OlivaceaC_WTR.stl"
beak28 =  mesh_dir / "G.ConirostrisA_WTR.stl"
beak29 =  mesh_dir / "G.ConirostrisB_WTR.stl"
beak30 =  mesh_dir / "G.ConirostrisC_WTR.stl"
beak31 =  mesh_dir / "P.CrassirostrisA_WTR.stl"
beak32 =  mesh_dir / "P.CrassirostrisB_WTR.stl"
beak33 =  mesh_dir / "P.CrassirostrisC_WTR.stl"
beaks_wtr = np.array([beak1, beak2, beak3, beak4, beak5, beak6, beak7, beak8, beak9, 
                  beak10, beak11, beak12, beak13, beak14, beak15, beak16, beak17, beak18,
                  beak19, beak20, beak21, beak22, beak23, beak24, beak25, beak26, beak27,
                  beak28, beak29, beak30, beak31, beak32, beak33]) 
        
        
beak1 =  mesh_dir / "G.FortisA.stl"
beak2 =  mesh_dir / "G.FortisB.stl"
beak3 =  mesh_dir / "G.FortisC.stl"
beak4 = mesh_dir / "G.MagnirostrisA.stl"
beak5 = mesh_dir / "G.MagnirostrisB.stl"
beak6 = mesh_dir / "G.MagnirostrisC.stl"
beak7 = mesh_dir /"G.FuliginosaA.stl"
beak8 = mesh_dir /"G.FuliginosaB.stl"
beak9 = mesh_dir /"G.FuliginosaC.stl"
beak10 =  mesh_dir / "G.ScandensA.stl"
beak11 =  mesh_dir / "G.ScandensB.stl"
beak12 =  mesh_dir / "G.SeptentrionalistA.stl"
beak13 =  mesh_dir / "G.SeptentrionalistB.stl"
beak14 = mesh_dir / "C.PallidusA.stl"
beak15 = mesh_dir / "C.PallidusB.stl"
beak16 = mesh_dir / "C.PallidusC.stl"
beak17 = mesh_dir / "C.PsittaculaA.stl"
beak18 = mesh_dir / "C.PsittaculaB.stl"
beak19 = mesh_dir / "C.PsittaculaC.stl"
beak20 = mesh_dir / "C.ParvulusA.stl"
beak21 = mesh_dir / "C.ParvulusB.stl"
beak22 = mesh_dir / "C.ParvulusC.stl"
beak23 = mesh_dir / "C2.FuscaA.stl"
beak24 = mesh_dir / "C2.FuscaB.stl"
beak25 = mesh_dir / "C2.OlivaceaA.stl"
beak26 = mesh_dir / "C2.OlivaceaB.stl"
beak27 = mesh_dir / "C2.OlivaceaC.stl"
beak28 =  mesh_dir / "G.ConirostrisA.stl"
beak29 =  mesh_dir / "G.ConirostrisB.stl"
beak30 =  mesh_dir / "G.ConirostrisC.stl"
beak31 =  mesh_dir / "P.CrassirostrisA.stl"
beak32 =  mesh_dir / "P.CrassirostrisB.stl"
beak33 =  mesh_dir / "P.CrassirostrisC.stl"
beaks_t = np.array([beak1, beak2, beak3, beak4, beak5, beak6, beak7, beak8, beak9, 
                  beak10, beak11, beak12, beak13, beak14, beak15, beak16, beak17, beak18,
                  beak19, beak20, beak21, beak22, beak23, beak24, beak25, beak26, beak27,
                  beak28, beak29, beak30, beak31, beak32, beak33]) 




matrix1 =  matrix_dir / "C.PallidusA_G.FortisA.txt"
matrix2 =  matrix_dir / "C.PallidusA_G.FortisB.txt"
matrix3 =  matrix_dir / "C.PallidusA_G.FortisC.txt"
matrix4 =  matrix_dir / "C.PallidusA_G.MagnirostrisA.txt"
matrix5 =  matrix_dir / "C.PallidusA_G.MagnirostrisB.txt"
matrix6 =  matrix_dir / "C.PallidusA_G.MagnirostrisC.txt"
matrix7 =  matrix_dir / "C.PallidusA_G.FuliginosaA.txt"
matrix8 =  matrix_dir / "C.PallidusA_G.FuliginosaB.txt"
matrix9 =  matrix_dir /  "C.PallidusA_G.FuliginosaC.txt"
matrix10 =  matrix_dir / "C.PallidusA_G.ScandensA.txt"
matrix11 =  matrix_dir / "C.PallidusA_G.ScandensB.txt"
matrix12 =  matrix_dir / "C.PallidusA_G.SeptentrionalistA.txt"
matrix13 =  matrix_dir / "C.PallidusA_G.SeptentrionalistB.txt"
matrix14 =  matrix_dir / "C.PallidusA_C.PallidusA.txt"
matrix15 =  matrix_dir / "C.PallidusA_C.PallidusB.txt"
matrix16 =  matrix_dir / "C.PallidusA_C.PallidusC.txt"
matrix17 =  matrix_dir / "C.PallidusA_C.PsittaculaA.txt"
matrix18 =  matrix_dir /  "C.PallidusA_C.PsittaculaB.txt"
matrix19 =  matrix_dir / "C.PallidusA_C.PsittaculaC.txt"
matrix20 =  matrix_dir / "C.PallidusA_C.ParvulusA.txt"
matrix21 =  matrix_dir / "C.PallidusA_C.ParvulusB.txt"
matrix22 =  matrix_dir / "C.PallidusA_C.ParvulusC.txt"
matrix23 =  matrix_dir / "C.PallidusA_C2.FuscaA.txt"
matrix24 =  matrix_dir / "C.PallidusA_C2.FuscaB.txt"
matrix25 =  matrix_dir / "C.PallidusA_C2.OlivaceaA.txt"
matrix26 =  matrix_dir / "C.PallidusA_C2.OlivaceaB.txt"
matrix27 =  matrix_dir /  "C.PallidusA_C2.OlivaceaC.txt"
matrix28 =  matrix_dir / "C.PallidusA_G.ConirostrisA.txt"
matrix29 =  matrix_dir / "C.PallidusA_G.ConirostrisB.txt"
matrix30 =  matrix_dir / "C.PallidusA_G.ConirostrisC.txt"
matrix31 =  matrix_dir / "C.PallidusA_P.CrassirostrisA.txt"
matrix32 =  matrix_dir / "C.PallidusA_P.CrassirostrisB.txt"
matrix33 =  matrix_dir / "C.PallidusA_P.CrassirostrisC.txt"
matrices = np.array([matrix1, matrix2, matrix3, matrix4, matrix5, matrix6, matrix7, matrix8, matrix9, 
                  matrix10, matrix11, matrix12, matrix13, matrix14, matrix15, matrix16, matrix17, matrix18,
                  matrix19, matrix20, matrix21, matrix22, matrix23, matrix24, matrix25, matrix26, matrix27,
                  matrix28, matrix29, matrix30, matrix31, matrix32, matrix33]) 
        
'''

'''
for i,  beak in enumerate(beaks_t):
    transformation = np.loadtxt(matrices[i])
    matrix = np.transpose(la.inv(transformation[:3, :3]))
    translation = - transformation[:3, 3]
    AM.apply_transformation(beak, matrix, translation)
'''


'''
result_dir1 = result_dir / "Run with CH Small Cut Not Aligned"
result_dir2 = result_dir / "Run with WTR Aligned"
result_dir3 = result_dir / "Run with CH Small Cut Aligned"

distances1 = np.load(result_dir1 / "distances.npy")

transformations1 = np.load(result_dir1 / "transformations.npy")


distances2 = np.load(result_dir2 / "distances.npy")
D2 = 0.5 * (distances2 + np.transpose(distances2))

transformations2 = np.load(result_dir2 / "transformations_scale_opt.npy")

distances3 = np.load(result_dir3 / "distances_affine_opt.npy")

transformations3 = np.load(result_dir3 / "transformations_affine_opt.npy")

shears1 = np.zeros_like(distances1) 
for i in range(distances1.shape[0]):
    for j in range(distances1.shape[0]):
        sl3 = transformations1[i,j]
        shears1[i,j] = max(np.abs([sl3[0,1],sl3[0,2],sl3[1,2]]))

shearsA3 = np.zeros_like(distances3) 
shearsB3 = np.zeros_like(distances3) 
for i in range(distances3.shape[0]):
    for j in range(distances1.shape[0]):
        sl3 = transformations3[i,j]
        
        shearsA3[i,j] = max(np.abs([sl3[0,1],sl3[0,2],sl3[1,2]]))
        shearsB3[i,j] = max(np.abs([sl3[0,1],sl3[0,2],sl3[1,2]]))/min(np.abs([sl3[0,0],sl3[1,1],sl3[2,2]]))
        
dets3 = np.zeros_like(distances3) 
for i in range(distances3.shape[0]):
    for j in range(distances1.shape[0]):
        sl3 = transformations3[i,j]
        dets3[i,j] = np.abs(la.det(sl3))
'''















